iT邦幫忙

2024 iThome 鐵人賽

DAY 4
0
Software Development

LSTM結合Yolo v8對於多隻斑馬魚行為分析系列 第 4

day 4 lstm對於斑馬魚的行為分析

  • 分享至 

  • xImage
  •  

第四天了我們可以上一些lstm斑馬魚的行為分析,但是是單隻斑馬魚的,以下是程式碼

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import LSTM, Dense

# 讀取數據
data = pd.read_csv('zebrafish_behavior.csv')

# 假設數據包含 'time', 'x', 'y' 列
data = data[['time', 'x', 'y']]

# 將時間轉換為數值
data['time'] = pd.to_datetime(data['time']).astype(int) / 10**9

# 預處理數據
scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(data)

# 創建數據集
def create_dataset(dataset, time_step=1):
    dataX, dataY = [], []
    for i in range(len(dataset)-time_step-1):
        a = dataset[i:(i+time_step), :]
        dataX.append(a)
        dataY.append(dataset[i + time_step, :])
    return np.array(dataX), np.array(dataY)

time_step = 10
X, y = create_dataset(scaled_data, time_step)

# 拆分訓練和測試數據
train_size = int(len(X) * 0.67)
test_size = len(X) - train_size
X_train, X_test = X[0:train_size], X[train_size:len(X)]
y_train, y_test = y[0:train_size], y[train_size:len(y)]

# 構建LSTM模型
model = Sequential()
model.add(LSTM(50, return_sequences=True, input_shape=(time_step, 3)))
model.add(LSTM(50, return_sequences=False))
model.add(Dense(25))
model.add(Dense(3))

model.compile(optimizer='adam', loss='mean_squared_error')

# 訓練模型
model.fit(X_train, y_train, batch_size=1, epochs=1)

# 預測
train_predict = model.predict(X_train)
test_predict = model.predict(X_test)

# 反轉縮放
train_predict = scaler.inverse_transform(train_predict)
test_predict = scaler.inverse_transform(test_predict)

# 視覺化結果
plt.figure(figsize=(14,5))
plt.plot(data['time'], data['x'], label='Actual X')
plt.plot(data['time'][:train_size], train_predict[:, 1], label='Train Predict X')
plt.plot(data['time'][train_size+time_step+1:], test_predict[:, 1], label='Test Predict X')
plt.xlabel('Time')
plt.ylabel('X Coordinate')
plt.legend()
plt.show()
data = pd.read_csv('zebrafish_behavior.csv')
data = data[['time', 'x', 'y']]

這部分程式碼從名為 zebrafish_behavior.csv 的 CSV 檔案中讀取數據,並選擇其中包含 timexy 列的數據。

時間數據處理

data['time'] = pd.to_datetime(data['time']).astype(int) / 10**9

將時間列轉換為數值格式,以便於數據處理。這裡將時間轉換為秒級別的 Unix 時間戳。

數據縮放

scaler = MinMaxScaler(feature_range=(0, 1))
scaled_data = scaler.fit_transform(data)

使用 MinMaxScaler 將數據縮放到 [0, 1] 範圍內,這有助於提升 LSTM 模型的訓練效果。

創建數據集

def create_dataset(dataset, time_step=1):
    dataX, dataY = [], []
    for i in range(len(dataset)-time_step-1):
        a = dataset[i:(i+time_step), :]
        dataX.append(a)
        dataY.append(dataset[i + time_step, :])
    return np.array(dataX), np.array(dataY)

time_step = 10
X, y = create_dataset(scaled_data, time_step)

定義一個函數來創建數據集。time_step 決定了每次訓練中使用多少時間步長。這裡選擇了 time_step=10,意味著每個輸入包含 10 個時間步長的數據。

拆分訓練和測試數據

train_size = int(len(X) * 0.67)
test_size = len(X) - train_size
X_train, X_test = X[0:train_size], X[train_size:len(X)]
y_train, y_test = y[0:train_size], y[train_size:len(y)]

將數據集拆分為訓練集和測試集,訓練集占總數據的 67%,測試集占 33%。

構建LSTM模型

model = Sequential()
model.add(LSTM(50, return_sequences=True, input_shape=(time_step, 3)))
model.add(LSTM(50, return_sequences=False))
model.add(Dense(25))
model.add(Dense(3))

model.compile(optimizer='adam', loss='mean_squared_error')

構建 LSTM 模型。使用 Sequential 模型添加兩層 LSTM 層和兩層全連接層(Dense 層)。第一層 LSTM 包含 50 個單元並返回完整序列,第二層 LSTM 包含 50 個單元但只返回最終的輸出。然後添加一個包含 25 個單元的全連接層,最後是包含 3 個單元的輸出層(對應 timexy)。

訓練模型

model.fit(X_train, y_train, batch_size=1, epochs=1)

訓練模型,這裡設置批次大小為 1,訓練 1 個 epoch。可以根據需求調整這些參數以提高模型性能。

預測

train_predict = model.predict(X_train)
test_predict = model.predict(X_test)

使用訓練好的模型對訓練集和測試集進行預測。

反轉縮放

train_predict = scaler.inverse_transform(train_predict)
test_predict = scaler.inverse_transform(test_predict)

將預測結果反轉縮放回原始數據範圍。

視覺化結果

plt.figure(figsize=(14,5))
plt.plot(data['time'], data['x'], label='Actual X')
plt.plot(data['time'][:train_size], train_predict[:, 1], label='Train Predict X')
plt.plot(data['time'][train_size+time_step+1:], test_predict[:, 1], label='Test Predict X')
plt.xlabel('Time')
plt.ylabel('X Coordinate')
plt.legend()
plt.show()

繪製圖形以視覺化實際數據與預測數據之間的比較。這裡繪製了實際的 x 坐標值,以及訓練集和測試集的預測 x 坐標值。

這段程式碼的目的是使用 LSTM 模型來分析和預測斑馬魚的行為,並將結果可視化以便於觀察和比較。整個過程包括數據讀取、預處理、模型構建與訓練、預測和結果可視化。


上一篇
day 3 Lstm天氣預測
下一篇
Day 5 yolo多隻斑馬魚行為分析
系列文
LSTM結合Yolo v8對於多隻斑馬魚行為分析29
圖片
  直播研討會
圖片
{{ item.channelVendor }} {{ item.webinarstarted }} |
{{ formatDate(item.duration) }}
直播中

尚未有邦友留言

立即登入留言